// SPDX-License-Identifier: GPL-3.0-or-later
// Copyright © 2018-2019 Ariadne Devos
/* sHT -- IO on stream sockets */

#include "fd.h"
#include <sHT/bugs.h>
#include <sHT/compiler.h>
#include <sHT/nospec.h>
#include <sHT/stream.h>
#include <sHT/test.h>
#include "worker.h"

#include <errno.h>
#include <stddef.h>
#include <stdint.h>
#include <sys/epoll.h>
#include <sys/socket.h>

/* Categories of errors returned by send(2) / recv(2) on a TCP socket,
   as produced by sHT_classify_sentrecv_tcp.

   These are sorted in order of expected prevalence. */
enum sHT_send_err_type {
        /* EAGAIN / EWOULDBLOCK: the socket is not ready, wait for epoll */
        sHT_SEND_BLOCKING,
        /* ECONNRESET: the peer reset the connection */
        sHT_SEND_GRACEFUL_RESET,
        /* ETIMEDOUT, EHOSTUNREACH: the connection died without the peer
           telling us */
        sHT_SEND_BLUNT_RESET,
        /* EPIPE: the write side of the connection has been closed */
        sHT_SEND_GRACEFUL_CLOSE,
        /* like in EINTR */
        sHT_SEND_INTERRUPTED,
        /* ENOBUFS: temporarily out of kernel buffer space; retry later
           (not the timeout / reset cases, those are categorised above) */
        sHT_SEND_TRANSIENT,
        /* ENOMEM: the kernel is out of memory */
        sHT_SEND_KERNEL_OOM,
        /* Anything we didn't expect.
           (We expect malicious clients,
           but no sHT bugs.)

           This should be logged as a warning. */
        sHT_SEND_OTHER,
};

/* Sort an errno value reported by send(2) or recv(2) on a TCP socket
   into one of the sHT_send_err_type categories.  The result depends on
   @var{err} alone; anything not listed below is sHT_SEND_OTHER.

   TODO: due to Spectre mitigations interfering with optimisation,
   inline into @var{sHT_socket_sendrecv_errno}. */
__attribute__((const))
static enum sHT_send_err_type
sHT_classify_sentrecv_tcp(int err)
{
	/* Not ready yet -- wait for the next epoll notification. */
	if (err == EAGAIN)
		return sHT_SEND_BLOCKING;
#if EWOULDBLOCK != EAGAIN
	if (err == EWOULDBLOCK)
		return sHT_SEND_BLOCKING;
#endif
	if (err == EINTR)
		return sHT_SEND_INTERRUPTED;
	/* The peer told us the connection is gone. */
	if (err == ECONNRESET)
		return sHT_SEND_GRACEFUL_RESET;
	if (err == EPIPE)
		return sHT_SEND_GRACEFUL_CLOSE;
	/* The connection died without notice. */
	if (err == ETIMEDOUT || err == EHOSTUNREACH)
		return sHT_SEND_BLUNT_RESET;
	/* Kernel buffer pressure: retry later (no busy loops?). */
	if (err == ENOBUFS)
		return sHT_SEND_TRANSIENT;
	if (err == ENOMEM)
		return sHT_SEND_KERNEL_OOM;
	return sHT_SEND_OTHER;
}

/* React to a failed send(2) / recv(2) on the socket of @var{task}.

   @var{err} is the errno value reported by the failing call.
   @var{epollflags} is the epoll flag of the direction that failed
   (EPOLLOUT when sending, EPOLLIN when receiving).  It is cleared from
   @var{task->task.epollflags} when the socket would block, or when the
   write side has been closed (EPIPE).

   Resets mark both directions as EOF and clear EPOLLIN and EPOLLOUT.
   Transient errors and kernel OOM reschedule the task, and OOM also
   flags the worker.

   Return true if the call should be retried directly (it was
   interrupted), false otherwise. */
__attribute__((nonnull (1, 2)))
static _Bool
sHT_socket_sendrecv_errno(struct sHT_worker *worker, struct sHT_task_stream *task, int err, uint32_t epollflags)
{
	/* Use the value the caller captured, not the global errno, so that
	   the result does not depend on what ran in between. */
	switch (sHT_classify_sentrecv_tcp(err)) {
	case sHT_SEND_BLOCKING:
		task->task.epollflags &= ~epollflags;
		return 0;
	case sHT_SEND_GRACEFUL_CLOSE:
		task->stream.flags |= sHT_STREAM_WRITE_EOF;
		task->task.epollflags &= ~epollflags;
		return 0;
	case sHT_SEND_GRACEFUL_RESET:
		task->stream.flags |= sHT_STREAM_WRITE_EOF | sHT_STREAM_READ_EOF | sHT_STREAM_RESET_GRACEFUL;
		task->task.epollflags &= ~(EPOLLIN | EPOLLOUT);
		return 0;
	case sHT_SEND_BLUNT_RESET:
		task->stream.flags |= sHT_STREAM_WRITE_EOF | sHT_STREAM_READ_EOF | sHT_STREAM_RESET_BLUNT;
		task->task.epollflags &= ~(EPOLLIN | EPOLLOUT);
		return 0;
	case sHT_SEND_INTERRUPTED:
		return 1;
	case sHT_SEND_TRANSIENT:
		/* TODO: may be a good idea to log these too,
		   as an informational message */
		task->task.flags |= sHT_TASK_SCHEDULE;
		return 0;
	case sHT_SEND_KERNEL_OOM:
		/* No, I don't like overcommiting.
		   Killing is better than hanging, though. */
		worker->flags |= sHT_WORKER_OOM;
		task->task.flags |= sHT_TASK_SCHEDULE;
		return 0;
	case sHT_SEND_OTHER:
		sHT_todo("didn't recognise TCP error");
		/* fallthrough */
	default:
		sHT_assert(0);
	}
	/* Only reached if sHT_todo / sHT_assert return (e.g. a build where
	   they are not fatal).  Never fall off the end of a non-void
	   function: don't retry. */
	return 0;
}

/* Send some of the pending output of @var{task} without blocking.

   The pending bytes are stored in the circular paper
   @var{task->stream.to_write}: @var{length} bytes starting at
   @var{offset}, wrapping around at sHT_PAPER_SIZE.  Only the
   contiguous run starting at @var{offset} is handed to send(2) in one
   call.  On success, @var{offset} and @var{length} are advanced by the
   number of bytes the kernel accepted.  On failure the error is
   handled by sHT_socket_sendrecv_errno, and the send is retried if
   that asks for it (EINTR).

   Assumes @var{offset} < sHT_PAPER_SIZE and
   @var{length} <= sHT_PAPER_SIZE (both are kept reduced modulo /
   bounded by the paper size elsewhere) -- TODO confirm. */
void
sHT_socket_sendsome_tcp(struct sHT_worker *worker, struct sHT_task_stream *task)
{
	const unsigned char *buf = task->stream.to_write.first;
	size_t start = task->stream.to_write.offset;
	size_t length = task->stream.to_write.length;
	/* Bytes between the read position and the end of the paper.
	   Sending more than that in one go would read past the buffer. */
	size_t contiguous = sHT_PAPER_SIZE - start;
	/* min(length, contiguous).  It is computed without subtracting
	   two positions, so it is also correct when the paper is full
	   (start == end of data) and it cannot wrap to a huge size.
	   TODO: do this branchless (feasible on x86, <sHT/minmax.h>) */
	size_t n = length;
	if (sHT_gt(n, contiguous))
		n = contiguous;
	do {
		ssize_t sent = send(task->stream.fd, buf + start, n, MSG_DONTWAIT | MSG_NOSIGNAL);
		if (sHT_lt0(sent))
			/* evaluate the loop condition: classify errno */
			continue;
		/* some data is on the kernel queue, or the NIC ... */
		sHT_assert((size_t) sent <= n);
		task->stream.to_write.offset = (start + (size_t) sent) % sHT_PAPER_SIZE;
		task->stream.to_write.length -= (size_t) sent;
		return;
		/* TODO intrusive Spectre mitigations*/
	} while (sHT_unlikely(sHT_socket_sendrecv_errno(worker, task, errno, EPOLLOUT)));
}

void
sHT_socket_readsome_tcp(struct sHT_worker *worker, struct sHT_task_stream *task)
{
	unsigned char *buf = task->stream.has_read.first;
	size_t start = (task->stream.has_read.offset + task->stream.has_read.length) % sHT_PAPER_SIZE;
	size_t end = task->stream.has_read.offset;
	if (sHT_gt(start, end))
		end = sHT_PAPER_SIZE;
	/* XXX: this doesn't seem correct */
	end %= sHT_PAPER_SIZE - start;
	/* XXX: speculatively negative sizes? */
	do {
		ssize_t received;
		received = recv(task->stream.fd, buf + start, end - start, MSG_DONTWAIT);
		if (sHT_lt0(received))
			continue;
		sHT_assert(received <= sHT_PAPER_SIZE - task->stream.has_read.length);
		task->stream.has_read.length += received;
		return;
		/* TODO intrusive Spectre mitigations*/
	} while (sHT_unlikely(sHT_socket_sendrecv_errno(worker, task, errno, EPOLLOUT)));
}
